# -*- coding: utf-8 -*-

from tvem.models import TVAE

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec
import matplotlib
import numpy as np
import argparse
import h5py
import torch as to
import tvem


def parse_args():
    """Parse this script's command-line options.

    All four options are mandatory; ``--gif-step`` is converted to ``int``.
    Returns the ``argparse.Namespace`` produced by the parser.
    """
    parser = argparse.ArgumentParser(description="Plot datapoints and learned parameters of a TVAE")
    option_specs = (
        (
            "--data",
            dict(help="data file. assumed to be in the format produced by gen_clustered_data.py"),
        ),
        (
            "--train-output",
            dict(
                help="learned parameters. assumed to be in the format produced by the TVEM framework"
            ),
        ),
        (
            "--output-dir",
            dict(help="directory where the output figures should be saved"),
        ),
        (
            "--gif-step",
            dict(help="a gif frame is produced every `gif_step` epochs", type=int),
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, required=True, **options)
    return parser.parse_args()


def plot_means(fig, params, result_idx=None):
    """Plot the TVAE means obtained for every singleton latent state.

    A TVAE is rebuilt from one stored set of parameters. Each one-hot state of the top
    layer is passed through ``TVAE.forward``. Each resulting mean (length D) is drawn as
    an R x R image, with D == R**2, in the lower-left quarter of ``fig`` (rows 2-3,
    columns 0-3 of a 4x8 grid).

    fig: matplotlib figure to draw into.
    params: a dict-like object with keys 'sigma2', 'W_<l>' and 'b_<l>' (one pair per
            layer l), each of which is expected to be a collection of such parameters
            ordered by iteration step.
    result_idx: which set of parameters to plot, out of the ones in the collection.
                By default, plot the last values.

    Raises ValueError if D is not a perfect square.
    """
    gs = GridSpec(nrows=4, ncols=8)

    if result_idx is None:
        result_idx = -1
    n_layers = sum(1 for k in params.keys() if k.startswith("W_"))
    W = tuple(to.tensor(params[f"W_{l}"][result_idx]) for l in range(n_layers))
    b = tuple(to.tensor(params[f"b_{l}"][result_idx]) for l in range(n_layers))
    sigma2 = float(params["sigma2"][result_idx])
    H0 = W[0].shape[0]

    m = TVAE(W_init=W, b_init=b, sigma2_init=sigma2)

    # row h of the identity matrix activates only top-layer unit h
    singletons = to.eye(H0).to(tvem.get_device())

    means = m.forward(singletons).detach().cpu()
    D = W[-1].shape[-1]
    assert means.shape == (H0, D)
    # exact integer check; comparing two float-derived ints depends on rounding
    R = int(round(np.sqrt(D)))
    if R * R != D:
        raise ValueError("D is not a perfect square")

    # symmetric color range so that 0 maps to the middle of the colormap
    vmax = means.abs().max().item()

    bars_axes = []
    for h in range(H0):
        ax = fig.add_subplot(gs[2 + h // 4, h % 4])
        bars_axes.append(ax)
        im = ax.imshow(means[h].reshape(R, R), vmin=-vmax, vmax=vmax, cmap="jet")
        # act on this axes/figure directly rather than on pyplot's "current" ones
        ax.axis("equal")
        ax.axis("off")

    fig.colorbar(im, ax=bars_axes)
    # title goes on the third image; fall back to the last one when there are fewer
    bars_axes[min(2, len(bars_axes) - 1)].set_title("Learnt means for singleton states")


def plot_weights(fig, theta, result_idx=None):
    """Plot the learned weights and biases of the top and bottom layers.

    If ``theta`` has more than one layer (key 'W_1' present), the top layer's weight
    matrix and bias ('W_0', 'b_0') are drawn in the first column of a 4x6 grid. The
    bottom layer's weights are drawn in columns 2-5, one R x R image per row of the
    weight matrix. Its bias is drawn as an R x R image in column 1. Here D == R**2 is
    the number of columns of the bottom weight matrix.

    fig: matplotlib figure to draw into.
    theta: dict-like object with keys 'W_<l>' and 'b_<l>', each a collection of
           parameters ordered by iteration step.
    result_idx: which set of parameters to plot, out of the ones in the collection.
                By default, plot the last values.

    Raises ValueError if D is not a perfect square.
    """
    gs = GridSpec(nrows=4, ncols=6)
    if result_idx is None:
        result_idx = -1

    if "W_1" in theta:
        W_top = np.asarray(theta["W_0"][result_idx])
        b_top = np.asarray(theta["b_0"][result_idx])

        # shared symmetric color range for top weights and bias
        vmax = np.max([np.abs(W_top).max(), np.abs(b_top).max()])

        # top weights
        ax_w = fig.add_subplot(gs[0, 0])
        ax_w.imshow(W_top, cmap="jet", vmin=-vmax, vmax=vmax)
        ax_w.set_title("Top layer weights")
        ax_w.axis("equal")
        ax_w.axis("off")

        # top bias
        ax_b = fig.add_subplot(gs[1, 0])
        im = ax_b.imshow(np.atleast_2d(b_top), cmap="jet", vmin=-vmax, vmax=vmax)
        ax_b.set_title("Top layer bias")
        ax_b.axis("equal")
        ax_b.axis("off")

        fig.colorbar(im, ax=(ax_w, ax_b))

    # bottom weights
    n_layers = sum(1 for k in theta.keys() if k.startswith("W_"))
    W_bottom = np.asarray(theta[f"W_{n_layers - 1}"][result_idx])
    b_bottom = np.asarray(theta[f"b_{n_layers - 1}"][result_idx])

    vmax = np.max([np.abs(W_bottom).max(), np.abs(b_bottom).max()])

    H_bottom, D = W_bottom.shape
    # exact integer check; comparing two float-derived ints depends on rounding
    R = int(round(np.sqrt(D)))
    if R * R != D:
        raise ValueError("D is not a perfect square")

    axes = []
    for h in range(H_bottom):
        ax = fig.add_subplot(gs[h // 4, 2 + h % 4])
        axes.append(ax)
        im = ax.imshow(W_bottom[h].reshape(R, R), vmin=-vmax, vmax=vmax, cmap="jet")
        ax.axis("equal")
        ax.axis("off")
    # title goes on the third image; fall back to the last one when there are fewer
    axes[min(2, len(axes) - 1)].set_title("Bottom weights")
    fig.colorbar(im, ax=axes)

    # bottom bias
    ax = fig.add_subplot(gs[:2, 1])
    ax.imshow(b_bottom.reshape(R, R), vmin=-vmax, vmax=vmax, cmap="jet")
    ax.set_title("Bottom bias")
    ax.axis("equal")
    ax.axis("off")


def plot_lines(ax, lines, title, gt_values=None):
    """Plot one or more curves on ``ax``, optionally with horizontal ground-truth lines.

    ax: matplotlib axes to draw on.
    lines: 1D array (a single curve) or 2D array (one curve per row).
    title: title of the axes.
    gt_values: optional scalar or array of reference values. Each value is drawn as a
               dashed black horizontal line from x=0 to x=len(curve). A single
               "ground truth" legend entry is added for all of them.
    """
    lines = np.atleast_2d(lines)
    for curve in lines:
        ax.plot(curve)

    ax.set_title(title)

    if gt_values is None:
        return
    # ravel (instead of atleast_1d) so multi-dimensional ground truths are supported too
    gt_values = np.ravel(gt_values)
    if gt_values.size == 0:
        # nothing to draw; avoid an empty legend
        return

    x_span = (0, lines[0].size)
    for i, v in enumerate(gt_values):
        v = v.item()
        # label only one line so the legend shows a single "ground truth" entry
        label_kwargs = {"label": "ground truth"} if i == 0 else {}
        ax.plot(x_span, (v, v), "k--", **label_kwargs)
    ax.legend()


def make_png(fig, free_energy, theta, ground_truth, until_epoch=None):
    """Draw a visualization of the training run into ``fig``.

    Draws the learned weights (``plot_weights``), the means for singleton states
    (``plot_means``), and the training curves of the free energy, 'pies' and 'sigma2'.
    Each curve is shown with its ground-truth value as a dashed line.

    fig: matplotlib figure to draw into.
    free_energy: 1D array of free-energy values, one per training epoch.
    theta: dict-like collection of learned parameters, each ordered by iteration step.
           It holds one more entry than ``free_energy`` (the initial values).
    ground_truth: dict-like object with keys 'logL', 'pies' and 'sigma2'.
    until_epoch: if given, plot the training curves only up to and including iteration
                 ``until_epoch`` (the same parameter set ``plot_means`` shows).
                 By default, plot everything.
    """
    # NOTE: plot_weights is called without an iteration index, so it always shows the
    # final parameters, also when until_epoch is given.
    plot_weights(fig, theta)
    plot_means(fig, theta, result_idx=until_epoch)

    gs = GridSpec(nrows=6, ncols=2)

    free_energy = np.asarray(free_energy, dtype=float)
    # make free energy same size as the other lines by prepending an invalid value:
    # currently, TVEM stores the initial values of the learned parameters but not the
    # initial value of the free energy. NaN is simply not drawn by matplotlib.
    padded_F = np.concatenate(([np.nan], free_energy))
    # plot_means shows entry `until_epoch`, so the slices must include it too
    stop = None if until_epoch is None else until_epoch + 1

    x_lims = (-10, padded_F.size + 10)
    F_ax = fig.add_subplot(gs[3, 1])
    plot_lines(F_ax, padded_F[:stop], title="Free energy", gt_values=ground_truth["logL"][...])
    F_ax.get_xaxis().set_visible(False)
    F_ax.set_xlim(*x_lims)
    # y limits from the whole run (so they don't change between frames), ignoring the
    # two lowest values -- presumably to keep early outliers from squashing the curve.
    # Non-finite values are skipped, and short runs fall back to the available values.
    finite_F = np.sort(free_energy[np.isfinite(free_energy)])
    if finite_F.size > 0:
        low = finite_F[min(2, finite_F.size - 1)]
        F_ax.set_ylim(low - 10, finite_F[-1] + 10)

    pi_ax = fig.add_subplot(gs[4, 1])
    plot_lines(pi_ax, theta["pies"][:stop].T, r"$\pi_h$", ground_truth["pies"][...])
    pi_ax.get_xaxis().set_visible(False)
    pi_ax.set_xlim(*x_lims)

    # assuming all covariance matrices are equal and all multiple of the identity matrix
    gt_sigma2 = ground_truth["sigma2"][...]
    sigma2_ax = fig.add_subplot(gs[5, 1])
    plot_lines(sigma2_ax, theta["sigma2"][:stop].T, r"$\sigma^2$", gt_sigma2)
    sigma2_ax.set_xlabel("iteration")
    sigma2_ax.set_xlim(*x_lims)
    sigma2_ax.get_yaxis().get_major_formatter().set_useOffset(False)


if __name__ == "__main__":
    import os

    args = parse_args()

    # Double the default figure size. Build a new list instead of scaling
    # plt.rcParams["figure.figsize"] in place: that list belongs to rcParams, and
    # mutating it would silently change the default size of every later figure.
    figsize = [2 * size for size in plt.rcParams["figure.figsize"]]
    fig = plt.figure(figsize=figsize)

    # Close both HDF5 files on exit, also on errors. `theta` and `ground_truth` are
    # h5py objects that are indexed later, so all plotting happens while the files are
    # still open.
    with h5py.File(args.data, "r") as data_file, h5py.File(args.train_output, "r") as train_out:
        ground_truth = data_file["ground_truth"]
        theta = train_out["theta"]
        free_energy = train_out["train_F"][...]

        # PNG
        make_png(fig, free_energy, theta, ground_truth)
        fig.savefig(os.path.join(args.output_dir, "viz.png"))

        # GIF (currently disabled; `--gif-step` is only used by this code)
        #def animation(until_epoch):
        #    # FIXME clearing the figure and recreating the axes every time is slow.
        #    # should change the data inside the axes instead
        #    fig.clear()
        #    make_png(fig, free_energy, theta, ground_truth, until_epoch=until_epoch)

        #n_epochs = free_energy.size
        #anim = FuncAnimation(
        #    fig, animation, range(0, n_epochs + 1, args.gif_step), interval=500, repeat_delay=2000
        #)
        #anim.save(os.path.join(args.output_dir, "viz.gif"), writer="imagemagick")
